{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "RSR datasets",
      "provenance": [],
      "private_outputs": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "l6f8HuBbhqqE",
        "colab_type": "text"
      },
      "source": [
        "Licensed under the Apache License, Version 2.0\n",
        "\n",
        "This Python notebook shows how to reproduce and inspect the datasets and splits used in *In-Domain Representation Learning for Remote Sensing*\n",
        "by Maxim Neumann, André Susano Pinto, Xiaohua Zhai and Neil Houlsby. A pre-print is\n",
        "available on [arXiv](https://arxiv.org/abs/1911.06721).\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "k1Zss9bIiLT6",
        "colab_type": "text"
      },
      "source": [
        "### Imports and code to define datasets"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "1_Mxo4AFhWNk",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "import tensorflow.compat.v2 as tf\n",
        "import tensorflow_datasets as tfds\n",
        "\n",
        "import tqdm\n",
        "\n",
        "import matplotlib.pylab as plt\n",
        "import seaborn as sns\n",
        "import numpy as np\n",
        "import pandas as pd\n",
        "\n",
        "# Switch TensorFlow to its v2 behavior; the code below calls `.numpy()` on dataset elements.\n",
        "tf.enable_v2_behavior()"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "bCPoemuBisAq",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Percentages used to carve train/val/test out of the TFDS `train` split\n",
        "# (used for every dataset except So2Sat).\n",
        "TRAIN_SPLIT_PERCENT = 60\n",
        "VALIDATION_SPLIT_PERCENT = 20\n",
        "# Only used to estimate the test sample count; the test slice itself is\n",
        "# everything after the train and validation percentages.\n",
        "TEST_SPLIT_PERCENT = 20\n",
        "# Share of So2Sat's TFDS `validation` split used as val; the rest is test.\n",
        "SO2SAT_VALIDATION_SUBSPLIT_PERCENT = 25\n",
        "\n",
        "class TfdsDataset:\n",
        "  def __init__(self, dataset_name, params={}, **kwargs):\n",
        "    is_so2sat = dataset_name.startswith(\"so2sat\")\n",
        "    is_bigearthnet = dataset_name.startswith(\"bigearthnet\")\n",
        "\n",
        "    self.name = dataset_name\n",
        "    self.is_multilabel = is_bigearthnet\n",
        "    self.builder = tfds.builder(dataset_name, **kwargs)\n",
        "\n",
        "    self.label_key = \"labels\" if is_bigearthnet else \"label\"\n",
        "    self.image_key = \"image\"\n",
        "    self.filename_key = \"sample_id\" if is_so2sat else \"filename\"\n",
        "\n",
        "    self._shuffle_buffer_size = params.get(\"shuffle_buffer_size\", 10000)\n",
        "    self._num_parallel_calls = params.get(\"num_preprocessing_threads\", 100)\n",
        "    self._drop_remainder = params.get(\"drop_remainder\", True)\n",
        "    self._ignore_errors = params.get(\"ignore_errors\", False)\n",
        "    self._prefetch = params.get(\"prefetch\", 1)\n",
        "\n",
        "    self.label_name_fn = self.builder.info.features[self.label_key].int2str\n",
        "    self.num_classes = self.builder.info.features[self.label_key].num_classes\n",
        "\n",
        "    if is_so2sat:\n",
        "      self._tfds_splits = dict(\n",
        "          train=f\"train\",\n",
        "          val=f\"validation[:{SO2SAT_VALIDATION_SUBSPLIT_PERCENT}%]\",\n",
        "          test=f\"validation[{SO2SAT_VALIDATION_SUBSPLIT_PERCENT}%:]\")\n",
        "      val_count = self.builder.info.splits[tfds.Split.VALIDATION].num_examples\n",
        "      self._num_samples_splits = dict(\n",
        "          train=self.builder.info.splits[tfds.Split.TRAIN].num_examples,\n",
        "          val=val_count * SO2SAT_VALIDATION_SUBSPLIT_PERCENT // 100,\n",
        "          test=val_count * (100-SO2SAT_VALIDATION_SUBSPLIT_PERCENT) // 100)\n",
        "    else:\n",
        "      self._tfds_splits = dict(\n",
        "          train=f\"train[:{TRAIN_SPLIT_PERCENT}%]\",\n",
        "          val=f\"train[{TRAIN_SPLIT_PERCENT}%:{TRAIN_SPLIT_PERCENT+VALIDATION_SPLIT_PERCENT}%]\",\n",
        "          test=f\"train[{TRAIN_SPLIT_PERCENT+VALIDATION_SPLIT_PERCENT}%:]\")\n",
        "      num_examples = self.builder.info.splits[tfds.Split.TRAIN].num_examples\n",
        "      self._num_samples_splits = dict(\n",
        "          train=num_examples * TRAIN_SPLIT_PERCENT // 100,\n",
        "          val=num_examples * VALIDATION_SPLIT_PERCENT // 100,\n",
        "          test=num_examples * TEST_SPLIT_PERCENT // 100)\n",
        "      \n",
        "  def _get_deterministic_dataset(self, split_name, for_eval, train_examples):\n",
        "    \"\"\"Creates a tf.data.Dataset composed of a deterministic set of examples.\"\"\"\n",
        "    # Don't shuffle to receive exactly the same split for reproducibility.\n",
        "    dataset = self.builder.as_dataset(\n",
        "        split=self._tfds_splits[split_name],\n",
        "        shuffle_files=False,\n",
        "        decoders={self.image_key: tfds.decode.SkipDecoding()})\n",
        "    num_samples = self._num_samples_splits[split_name]\n",
        "\n",
        "    if not for_eval and train_examples:\n",
        "      dataset = dataset.take(train_examples)\n",
        "      num_samples = train_examples\n",
        "\n",
        "    return dataset, num_samples\n",
        "\n",
        "  def get_filenames(self, split_name, train_examples=None, for_eval=False):\n",
        "    dataset, num_samples = self._get_deterministic_dataset(split_name, for_eval, train_examples)\n",
        "    def _get(example):\n",
        "      fname = example[self.filename_key].numpy()\n",
        "      if np.issubdtype(type(fname), np.signedinteger):\n",
        "        fname = bytes(str(fname), encoding=\"utf-8\")\n",
        "      return fname\n",
        "    return list([_get(x) for x in dataset])\n",
        "    \n",
        "  def get_tf_data(self, split_name, batch_size, preprocess_fn=None,\n",
        "                  for_eval=False, train_examples=None, epochs=None):\n",
        "    \"\"\"Creates a tf.data.Dataset with features (label, image, filename).\"\"\"\n",
        "    dataset, num_samples = self._get_deterministic_dataset(split_name, for_eval, train_examples)\n",
        "      \n",
        "    # Cache the whole dataset if it's smaller than 150K examples.\n",
        "    if not for_eval and num_samples <= 150000:\n",
        "      dataset = dataset.cache()\n",
        "\n",
        "    # Repeats data `epochs` time or indefinitely if `epochs` is None.\n",
        "    if epochs is None or epochs > 1:\n",
        "      dataset = dataset.repeat(epochs)\n",
        "\n",
        "    if not for_eval and self._shuffle_buffer_size > 1:\n",
        "      dataset = dataset.shuffle(self._shuffle_buffer_size)\n",
        "\n",
        "    def prepare_example(example):\n",
        "      image_decoder = self.builder.info.features[self.image_key].decode_example\n",
        "      # Rename features to common names.\n",
        "      example = {\n",
        "          \"image\": image_decoder(example[self.image_key]),\n",
        "          \"label\": example[self.label_key],\n",
        "          \"filename\": example[self.filename_key],\n",
        "      }\n",
        "      if self.is_multilabel:\n",
        "        example[\"label\"] = tf.reduce_max(tf.one_hot(example[\"label\"],\n",
        "                                                    depth=self.num_classes,\n",
        "                                                    dtype=tf.int64), axis=0)\n",
        "      if preprocess_fn:\n",
        "        example = preprocess_fn(example)\n",
        "      return example\n",
        "\n",
        "    dataset = dataset.map(prepare_example, self._num_parallel_calls)\n",
        "    if self._ignore_errors:  # Ignore images with errors.\n",
        "      dataset = dataset.apply(tf.data.experimental.ignore_errors())\n",
        "    dataset = dataset.batch(batch_size, self._drop_remainder)\n",
        "    dataset = dataset.prefetch(self._prefetch)\n",
        "    return dataset\n",
        "\n",
        "def preprocess_fn(data, size=224, input_range=(0.0, 1.0)):\n",
        "  \"\"\"Resizes the image to `size` x `size` and rescales it into `input_range`.\n",
        "\n",
        "  Args:\n",
        "    data: Example dict with an `image` entry; it is updated in place.\n",
        "    size: Target height and width in pixels.\n",
        "    input_range: (low, high) range the pixel values are mapped into after\n",
        "      being divided by 255 (presumably the inputs are in [0, 255]).\n",
        "\n",
        "  Returns:\n",
        "    The same dict with `image` replaced by the float32 result.\n",
        "  \"\"\"\n",
        "  low, high = input_range[0], input_range[1]\n",
        "  resized = tf.image.resize(data[\"image\"], [size, size])\n",
        "  unit_scaled = tf.cast(resized, tf.float32) / 255.0\n",
        "  data[\"image\"] = unit_scaled * (high - low) + low\n",
        "  return data"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "0O2eSq9EhwzA",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "def visualize(ds, figsize=(17, 17)):\n",
        "  \"\"\"Plots a 4x4 grid of preprocessed images from the `val` split with labels.\n",
        "\n",
        "  Also prints the dataset name plus shape and summary statistics of the image\n",
        "  and label batches.\n",
        "\n",
        "  Args:\n",
        "    ds: A TfdsDataset instance.\n",
        "    figsize: Matplotlib figure size.\n",
        "  \"\"\"\n",
        "  batch_size = 16\n",
        "  batches = ds.get_tf_data(\"val\", batch_size,  preprocess_fn=preprocess_fn)\n",
        "  # With TF2 behavior a dataset is a plain Python iterable;\n",
        "  # `make_one_shot_iterator` is not a method of TF2 dataset objects.\n",
        "  xx = next(iter(batches))\n",
        "  print(f\"Dataset: {ds.name}\")\n",
        "  print(\"Images: \", xx[\"image\"].shape, stats_str(xx[\"image\"]))\n",
        "  print(\"Labels: \", xx[\"label\"].shape, stats_str(xx[\"label\"]))\n",
        "  plt.figure(figsize=figsize)\n",
        "  for i in range(batch_size):\n",
        "    plt.subplot(4, 4, 1+i)\n",
        "    plt.imshow(xx[\"image\"][i])\n",
        "    if ds.is_multilabel:\n",
        "      # Multi-hot label vector: list the name of every active class.\n",
        "      labels = [ds.label_name_fn(lid) for lid, value in enumerate(xx[\"label\"][i]) if value]\n",
        "      plt.title(\"\\n\".join(labels))\n",
        "    else:\n",
        "      plt.title(ds.label_name_fn(xx[\"label\"][i]))\n",
        "  plt.show()\n",
        "\n",
        "def stats_str(arr, f=None, with_median=False, with_count=False):\n",
        "  \"\"\"Returns a string with main stats info about the given array.\n",
        "\n",
        "  By default, the string has the form: \"mean +/- standard_deviation [min..max]\"\n",
        "  values of the data array.\n",
        "\n",
        "  Args:\n",
        "    arr: array-like\n",
        "    f: str\n",
        "    with_median: boolean\n",
        "    with_count: boolean\n",
        "  Returns:\n",
        "    stats_str: str\n",
        "  \"\"\"\n",
        "  if arr is None or (isinstance(arr, (list, tuple)) and not arr):\n",
        "    return \"[empty]\"\n",
        "  if not isinstance(arr, np.ndarray):\n",
        "    try:\n",
        "      arr = arr.numpy()  # If arr is a TF-2 tensor.\n",
        "    except AttributeError:\n",
        "      pass\n",
        "    try:\n",
        "      arr = np.concatenate(arr).ravel()  # to deal with different length lists\n",
        "    except ValueError:\n",
        "      arr = np.array(arr)\n",
        "  if f is None:\n",
        "    f = \"{:.3f}\"\n",
        "  if with_median:\n",
        "    median = (\" median: \" + f).format(np.median(arr))\n",
        "  else:\n",
        "    median = \"\"\n",
        "  count = \" n: {:,}\".format(len(arr)) if with_count else \"\"\n",
        "  pm = \"+/-\"  # this one doesn't work: u' \\u00B1'\n",
        "  if arr.dtype.kind in [\"i\", \"u\"]:\n",
        "    return (f + pm + f + \" [{}..{}]\").format(arr.mean(), arr.std(), arr.min(),\n",
        "                                             arr.max()) + median + count\n",
        "  return (f + pm + f + \" [\" + f + \"..\" + f + \"]\").format(\n",
        "      arr.mean(), arr.std(), arr.min(), arr.max()) + median + count"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "SlI5jYBTiRF1",
        "colab_type": "text"
      },
      "source": [
        "## Inspect a specific dataset\n",
        "\n",
        "Attention: many datasets take a long time to build and may require multiple days to download. So2Sat must be set up manually."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "OvNv3aFbiGNg",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Prepare and load a dataset\n",
        "# Dataset names this notebook is meant for; informational only, the list is\n",
        "# not read anywhere in this notebook.\n",
        "POSSIBLE_DATASET_NAMES = [\"bigearthnet\", \"eurosat\", \"resisc45\", \"so2sat\", \"uc_merced\"]\n",
        "DATASET_NAME = \"uc_merced\"\n",
        "\n",
        "# The trailing `;` keeps the cell from displaying the returned value.\n",
        "tfds.load(DATASET_NAME);  # This will trigger downloading and preparing the dataset.\n",
        "\n",
        "# Visualize examples\n",
        "dataset = TfdsDataset(DATASET_NAME)\n",
        "visualize(dataset)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "vCca62f4iXfr",
        "colab_type": "text"
      },
      "source": [
        "## Dump the splits used in a dataset"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "SG7OkXLZiJaf",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Dump the specific splits used on the dataset.\n",
        "# One text file per split is written, with one example identifier per line.\n",
        "dataset = TfdsDataset(DATASET_NAME)\n",
        "# TFDS names may carry a config (`name/config`); a `/` in the file name would\n",
        "# point into a directory that does not exist, so flatten it.\n",
        "file_stem = dataset.name.replace(\"/\", \"_\")\n",
        "for split in [\"train\", \"val\", \"test\"]:\n",
        "  filenames = dataset.get_filenames(split)\n",
        "  output = f\"/tmp/{file_stem}-{split}.txt\"\n",
        "  with open(output, \"wb\") as f:\n",
        "    f.write(b\"\\n\".join(filenames))\n",
        "  print(f\"Wrote: {output}\")"
      ],
      "execution_count": 0,
      "outputs": []
    }
  ]
}